{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Implementation of Sparse Mixtures-of-Experts layer in PyTorch from Mistral Official Repo\n",
    "\n",
    "📌 https://github.com/mistralai/mistral-src/blob/main/mistral/moe.py\n",
    "\n",
    "And its super simple.\n",
    "\n",
    "--------"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dataclasses\n",
    "from typing import List\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from simple_parsing.helpers import Serializable\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "@dataclasses.dataclass\n",
    "class MoeArgs(Serializable):\n",
    "    num_experts: int\n",
    "    num_experts_per_tok: int\n",
    "\n",
    "\n",
    "class MoeLayer(nn.Module):\n",
    "    def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs):\n",
    "        super().__init__()\n",
    "        assert len(experts) > 0\n",
    "        self.experts = nn.ModuleList(experts)\n",
    "        self.gate = gate\n",
    "        self.args = moe_args\n",
    "\n",
    "    def forward(self, inputs: torch.Tensor):\n",
    "        gate_logits = self.gate(inputs)\n",
    "        weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok)\n",
    "        weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype)\n",
    "        results = torch.zeros_like(inputs)\n",
    "        for i, expert in enumerate(self.experts):\n",
    "            batch_idx, nth_expert = torch.where(selected_experts == i)\n",
    "            results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(\n",
    "                inputs[batch_idx]\n",
    "            )\n",
    "        return results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "📌 `torch.topk()` is used over the gate outputs to find the best expert per training example. It computes the top `num_experts_per_tok` logits for each token across the expert dimension. This operation returns two tensors: the top logits (`weights`) and their corresponding expert indices (`selected_experts`). \n",
    "\n",
    "📌 Then `torch.where()` to determine which training examples in the batch should be routed to which expert and so uses `selected_experts` to map each token to its allocated experts. The gating mechanism's sparsity is embodied here, as each token is only routed to a limited set of experts (as defined by `num_experts_per_tok`), rather than all available experts.\n",
    "\n",
    "📌 `torch.where(selected_experts == i)` is used to find indices in `selected_experts` where its elements equal `i`. This returns two tensors:\n",
    "- **batch_idx**: The indices of the batch dimension where the condition holds true.\n",
    "- **nth_expert**: The indices along the second dimension (the expert dimension in this context) for each true element in the condition.\n",
    "\n",
    "📌 The softmax applied to `weights` normalizes these logits, converting them into a probability distribution over the selected experts for each token. This step ensures that the contribution of each selected expert is weighted proportionally to its predicted relevance."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "----------------\n",
    "\n",
    "More Explanations on the 2 key steps 🔽\n",
    "\n",
    "\n",
    "📌 `torch.topk()` returns the `k` largest elements from the given input tensor along a specified dimension. The function returns two tensors: the first contains the top `k` values, and the second contains the indices of these values in the tensor.\n",
    "\n",
    "Here,\n",
    "\n",
    "- `gate_logits` represents the output from the gating mechanism, which is essentially the scores or logits indicating how much each training example is relevant to each expert.\n",
    "\n",
    "- `torch.topk(gate_logits, self.args.num_experts_per_tok)` finds the top `k` experts (where `k` is `self.args.num_experts_per_tok`) for each token or training example. The returned values are:\n",
    "    \n",
    "    - `weights`: The scores or probabilities of each of the top `k` experts (i.e. the gate logits)\n",
    "    - `selected_experts`: The indices of these top `k` experts.\n",
    "\n",
    "📌 The `torch.topk` function, by default, operates on the last dimension of the input tensor unless otherwise specified by the `dim` argument. Since `gate_logits` is not explicitly reshaped or permuted in the code before the `topk` call, and the `dim` argument is not provided, it is logical to deduce that the operation is performed across the expert dimension, which is the last dimension in the `gate_logits` tensor. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "------------\n",
    "\n",
    "📌 `torch.where()` is used for conditional selection of elements from tensors. The function's signature is `torch.where(condition, x, y)`. It takes three arguments:\n",
    "\n",
    "- **condition**: A boolean tensor. The shape of the condition tensor dictates the shape of the output.\n",
    "- **x**: Tensor (or scalar) from which to take elements when the corresponding value in `condition` is `True`.\n",
    "- **y**: Tensor (or scalar) from which to take elements when the corresponding value in `condition` is `False`.\n",
    "\n",
    "📌 The output tensor is formed by selecting elements from `x` or `y` based on the `condition`. If `condition[i, j, ...] == True`, the output at that location is `x[i, j, ...]`; otherwise, it is `y[i, j, ...]`. \n",
    "\n",
    "📌 In this implementation here, `torch.where()` is used differently. It's used to find indices where a condition is true.\n",
    "\n",
    "📌 Here, `torch.where(selected_experts == i)` is used to find indices in `selected_experts` where its elements equal `i`. This returns two tensors:\n",
    "- **batch_idx**: The indices of the batch dimension where the condition holds true.\n",
    "- **nth_expert**: The indices along the second dimension (the expert dimension in this context) for each true element in the condition.\n",
    "\n",
    "📌 These indices (`batch_idx` and `nth_expert`) are then used to route the inputs to the appropriate expert in the Mixture of Experts (MoE) layer. For each expert `i`, it finds which inputs (`inputs[batch_idx]`) should be processed by that expert. The results are scaled by the corresponding weights (`weights[batch_idx, nth_expert, None]`) and accumulated in `results`."
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
